{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Intro\n",
    "Exploring the task of sketch cleaning (also automatized linearts).\n",
    "\n",
    "Based on [Edgar Simo-Serra・Sketch Simplification](http://hi.cs.waseda.ac.jp/~esimo/en/research/sketch/)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T11:20:09.895202",
     "start_time": "2017-07-24T11:19:48.903002Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "import pdb\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import glob\n",
    "\n",
    "import os\n",
    "import sys\n",
    "from os.path import join\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import animation\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.models import Model\n",
    "from keras.layers.core import Activation, Dense\n",
    "from keras import backend as K\n",
    "from keras import optimizers\n",
    "\n",
    "sys.path.append(join(os.getcwd(), os.pardir))\n",
    "from utils import image_processing\n",
    "\n",
    "sns.set_style(\"dark\")\n",
    "sns.set_context(\"paper\")\n",
    "\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T11:20:09.900203",
     "start_time": "2017-07-24T11:20:09.897203Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "RES_DIR = join(*[os.pardir]*2, 'data', 'sketch_dataset')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T11:20:11.170275",
     "start_time": "2017-07-24T11:20:11.153274Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def load_dataset_info(dirpath):\n",
    "    data = pd.DataFrame(columns=['dirpath', 'ref', 'sketch', 'sketch2', 'lineart'])\n",
    "    for f in glob.glob(join(dirpath, '**', \"*.jpg\"), recursive=True):\n",
    "        filepath, filename = f.rsplit('\\\\', 1)\n",
    "        img_id, category = filename.split('.')[0].split('_')\n",
    "        data.set_value(img_id, category, filename)\n",
    "        data.set_value(img_id, 'dirpath', filepath)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T11:20:16.791597",
     "start_time": "2017-07-24T11:20:16.708592Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "dataset_info = load_dataset_info(RES_DIR)\n",
    "dataset_info.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# load sketch\n",
    "content_image = None\n",
    "with Image.open(os.path.join(RES_DIR, 'superman.jpg')) as img:\n",
    "    img = img.resize((height, width))\n",
    "    content_image = np.asarray(img, dtype='float32')\n",
    "    plt.imshow(img.convert(mode='RGB'))\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-17T09:33:15.175366",
     "start_time": "2017-07-17T09:33:14.151307Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "f, axarr = plt.subplots(len(dataset_info), 3)\n",
    "categories = {0:'ref', 1:'sketch', 2:'lineart'}\n",
    "for row_idx, (img_id, row) in enumerate(dataset_info.iterrows()):\n",
    "    for i in range(3):\n",
    "        img = plt.imread(join(row['dirpath'], row[categories[i]]))\n",
    "        axarr[row_idx, i].imshow(img)\n",
    "        axarr[row_idx, i].axis('off')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Augmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-23T17:08:31.391540",
     "start_time": "2017-07-23T17:08:31.132525Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from skimage import data\n",
    "from skimage import transform\n",
    "from skimage import io"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-23T17:26:26.681043",
     "start_time": "2017-07-23T17:26:26.662041Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def augment(imgs_info, n, suffix_folder=\"\"):\n",
    "    for i in range(n):\n",
    "        for img_id, row in imgs_info.iterrows():\n",
    "            rotation = np.random.randint(360)\n",
    "            for cat in ['ref', 'sketch', 'sketch2', 'lineart']:\n",
    "                if not pd.isnull(row[cat]):\n",
    "                    origin_img = plt.imread(join(row['dirpath'], row[cat]))\n",
    "                    dest_img = transform.rotate(origin_img, rotation, mode='edge')\n",
    "                    filename = \"{}{}{}_{}.jpg\".format(img_id, 'gen', i, cat)\n",
    "                    io.imsave(join(row['dirpath'], suffix_folder, filename), dest_img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-23T17:26:48.721303",
     "start_time": "2017-07-23T17:26:28.615153Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "augment(dataset_info, 10, 'augmented')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-23T16:52:58.157162",
     "start_time": "2017-07-23T16:52:58.151161Z"
    }
   },
   "source": [
    "# Train (CNN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:15:34.414262",
     "start_time": "2017-07-24T13:15:34.407262Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import Convolution2D, MaxPooling2D, Conv2DTranspose\n",
    "from keras.layers import Activation, Dropout, Flatten, Dense\n",
    "from keras.preprocessing import image\n",
    "from keras import optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T11:20:21.735880",
     "start_time": "2017-07-24T11:20:21.732880Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "img_shape = (128, 128, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:22:08.941828",
     "start_time": "2017-07-24T13:22:06.647697Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X = image_processing.load_data(dataset_info.apply(lambda x : join(x['dirpath'], x['lineart']), axis=1).values, img_shape[:2])\n",
    "y = image_processing.load_data(dataset_info.apply(lambda x : join(x['dirpath'], x['ref']), axis=1).values, img_shape[:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:22:09.386853",
     "start_time": "2017-07-24T13:22:09.375853Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X_train = X#(X/255)\n",
    "y_train = y#(y/255)\n",
    "print(X_train.shape)\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:37:30.858559",
     "start_time": "2017-07-24T13:37:30.633546Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = Sequential()\n",
    "model.add(Convolution2D(32, (5, 5), strides=(2, 2), padding='same', input_shape=img_shape,\n",
    "         activation='relu'))\n",
    "model.add(Convolution2D(64, (3, 3), padding='same',\n",
    "         activation='relu'))\n",
    "\n",
    "model.add(Convolution2D(128, (5, 5), strides=(2, 2), padding='same',\n",
    "         activation='relu'))\n",
    "\n",
    "model.add(Convolution2D(256, (3, 3), padding='same',\n",
    "         activation='relu'))\n",
    "model.add(Convolution2D(256, (3, 3), padding='same',\n",
    "         activation='relu'))\n",
    "\n",
    "model.add(Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same',\n",
    "         activation='relu'))\n",
    "model.add(Convolution2D(64, (3, 3), padding='same',\n",
    "         activation='relu'))\n",
    "\n",
    "model.add(Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same',\n",
    "         activation='relu'))\n",
    "          \n",
    "model.add(Convolution2D(3, (5, 5), padding='same',\n",
    "         activation='sigmoid'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:37:31.334586",
     "start_time": "2017-07-24T13:37:31.326585Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:37:52.260783",
     "start_time": "2017-07-24T13:37:52.232781Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "optimizer = optimizers.Adam(lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:37:53.834873",
     "start_time": "2017-07-24T13:37:53.784870Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "model.compile(loss='mean_squared_error', optimizer=optimizer, metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:39:39.703928",
     "start_time": "2017-07-24T13:38:53.730299Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model.fit(y_train, y_train, batch_size=2, epochs=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:22:40.293621",
     "start_time": "2017-07-24T13:22:40.234618Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.imshow(X_train[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:38:21.468453",
     "start_time": "2017-07-24T13:38:21.399449Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.imshow(y_train[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:40:09.671642",
     "start_time": "2017-07-24T13:40:09.624639Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pred = model.predict(y_train[2].reshape((1, *img_shape)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:39:57.998975",
     "start_time": "2017-07-24T13:39:57.994974Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "pred.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:40:10.894712",
     "start_time": "2017-07-24T13:40:10.828708Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "plt.imshow(pred[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:34:20.973698",
     "start_time": "2017-07-24T13:34:20.951697Z"
    },
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "pred"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train (GAN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:04:04.223786",
     "start_time": "2017-07-24T13:04:04.217785Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers.core import Activation, Dense\n",
    "from keras import backend as K\n",
    "from keras import optimizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:45:00.283264",
     "start_time": "2017-07-24T13:45:00.278264Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "img_shape = (128, 128, 3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:46:12.787411",
     "start_time": "2017-07-24T13:46:10.559284Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "X_train = image_processing.load_data(dataset_info.apply(lambda x : join(x['dirpath'], x['ref']), axis=1).values, img_shape[:2])\n",
    "y_train = image_processing.load_data(dataset_info.apply(lambda x : join(x['dirpath'], x['lineart']), axis=1).values, img_shape[:2])\n",
    "print(X_train.shape)\n",
    "print(y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:46:14.340500",
     "start_time": "2017-07-24T13:46:14.316499Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator(img_shape):\n",
    "    model = Sequential()\n",
    "    model.add(Convolution2D(32, (5, 5), strides=(2, 2), padding='same', input_shape=img_shape,\n",
    "             activation='relu'))\n",
    "    model.add(Convolution2D(64, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Convolution2D(128, (5, 5), strides=(2, 2), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Convolution2D(256, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "    model.add(Convolution2D(256, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same',\n",
    "             activation='relu'))\n",
    "    model.add(Convolution2D(64, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Convolution2D(3, (5, 5), padding='same',\n",
    "             activation='sigmoid'))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:46:17.503681",
     "start_time": "2017-07-24T13:46:17.443678Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def discriminator(img_shape):\n",
    "    model = Sequential()\n",
    "    model.add(Convolution2D(32, (5, 5), strides=(2, 2), padding='same', input_shape=img_shape,\n",
    "             activation='relu'))\n",
    "    model.add(Convolution2D(64, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Convolution2D(128, (5, 5), strides=(2, 2), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Convolution2D(256, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "    model.add(Convolution2D(256, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same',\n",
    "             activation='relu'))\n",
    "    model.add(Convolution2D(64, (3, 3), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same',\n",
    "             activation='relu'))\n",
    "\n",
    "    model.add(Flatten())\n",
    "    model.add(Dense(512, activation='relu'))\n",
    "    #model.add(Dropout(0.5))\n",
    "    model.add(Dense(1, activation=K.sigmoid))\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:46:27.113231",
     "start_time": "2017-07-24T13:46:26.698207Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# init GAN components\n",
    "d = discriminator(img_shape)\n",
    "g = generator(img_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:47:05.826445",
     "start_time": "2017-07-24T13:47:05.769442Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# discriminator model\n",
    "optimizer = optimizers.RMSprop(lr=0.0008, clipvalue=1.0, decay=6e-8)\n",
    "discriminator_model = d\n",
    "discriminator_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-07-24T13:47:06.667493",
     "start_time": "2017-07-24T13:47:06.500483Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# adversarial model\n",
    "optimizer = optimizers.RMSprop(lr=0.0004, clipvalue=1.0, decay=3e-8)\n",
    "adversarial_model = Sequential()\n",
    "adversarial_model.add(g)\n",
    "adversarial_model.add(d)\n",
    "adversarial_model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for step in range(epochs):\n",
    "    \n",
    "    # generate data\n",
    "    # first we sample from the true distribution, then we generate some\n",
    "    # \"fake\" data by feeding noise to the generator\n",
    "    true_sample = np.reshape(gaussian_d.sample(batch_size), (batch_size, 1))\n",
    "    noise = generator_d.sample(batch_size)\n",
    "    fake_sample = g.predict(noise)\n",
    "    #pdb.set_trace()\n",
    "    \n",
    "    # train discriminator\n",
    "    # feed true and fake samples with respective labels (1, 0) to the discriminator\n",
    "    x = np.reshape(np.concatenate((true_sample, fake_sample)), (batch_size*2, 1))\n",
    "    y = np.ones([batch_size*2, 1])\n",
    "    y[batch_size:, :] = 0\n",
    "    d_loss = discriminator_model.train_on_batch(x, y)\n",
    "    \n",
    "    # train GAN\n",
    "    # feed noise to the model and expect true (1) response from discriminator,\n",
    "    # which is in turn fed with data generated by the generator\n",
    "    noise = np.reshape(generator_d.sample(batch_size), (batch_size, 1))\n",
    "    y = np.ones([batch_size, 1])\n",
    "    a_loss = adversarial_model.train_on_batch(noise, y)\n",
    "    \n",
    "    log_mesg = \"%d: [D loss: %f, acc: %f]\" % (step, d_loss[0], d_loss[1])\n",
    "    log_mesg = \"%s  [A loss: %f, acc: %f]\" % (log_mesg, a_loss[0], a_loss[1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:tensorflow-gpu]",
   "language": "python",
   "name": "conda-env-tensorflow-gpu-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "30px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
